
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

def _read_until(infile, marker, line=''):
    """Return the first line containing *marker*, reading forward from *infile*.

    *line* (the line the caller currently holds) is tested before anything is
    read, so passing the current line reproduces a "check, then read" loop.

    Raises ValueError if end of file is reached first; without this check the
    loop would spin forever on the '' that readline() returns at EOF.
    """
    while marker not in line:
        line = infile.readline()
        if not line:
            raise ValueError('unexpected end of file: %r not found' % (marker,))
    return line


def _column_indices(fragments):
    """Turn header fragments such as '(12)   ' into 1-based indices (12)."""
    return [int(frag.replace('(', '').replace(')', '')) for frag in fragments]


def filereader(fname):
    """Parse one model output file into reaction names, rate constants and rates.

    Sections are located by their marker lines: the dimensions header,
    'REACTIONS:', 'KINETIC RATE CONSTANTS', 'PHOTOCHEMICAL RATE CONSTANTS:'
    and 'REACTION RATES:'.

    Parameters
    ----------
    fname : str
        Path of the output file to read.

    Returns
    -------
    dict
        'rnames' : list of str, reaction text with all whitespace removed.
        'k'      : ndarray (nz, len(rnames)), kinetic and photochemical
                   rate constants, column i-1 holding reaction i.
        'rates'  : ndarray (nz, len(rnames)), reaction rates, same layout.
        'z'      : ndarray (nz,), second column of the photochemical table.

    Raises
    ------
    ValueError
        If a required section marker is missing before end of file, or a
        header index / value cannot be converted to a number.
    """
    with open(fname, 'r') as infile:
        # DIMENSIONS
        infile.readline()  # dims
        infile.readline()  # header
        vals = infile.readline().split()
        # Only nz is used below, but the other counts are still converted so
        # a malformed dimensions line is rejected here.
        natom, nmol, nreact, nfix, nvary = (
            int(vals[0]), int(vals[1]), int(vals[2]), int(vals[4]), int(vals[6]))
        infile.readline()  # header
        nz = int(infile.readline().split()[2])  # rows in every table below

        # REACTIONS: one reaction per line until the CONCENTRATIONS marker.
        _read_until(infile, 'REACTIONS:')
        rnames = []
        line = infile.readline()
        while 'CONCENTRATIONS' not in line:
            if not line:
                raise ValueError("unexpected end of file: 'CONCENTRATIONS' not found")
            # if termolecular k continuing, skip:
            if not line.strip().startswith('k'):
                # Reaction text sits in columns 8-80; drop all whitespace,
                # including the newline left on lines shorter than 80 chars.
                rname = ''.join(line[7:80].split())
                if '=' in rname:  # make sure it's a reaction, not a warning
                    rnames.append(rname)
            line = infile.readline()

        # KINETIC RATE CONSTANTS: blocks of columns headed "K(i)".
        k = np.zeros((nz, len(rnames)))
        _read_until(infile, 'KINETIC RATE CONSTANTS', line)
        blank = infile.readline()
        # Seed with 11 fragments so the loop body runs at least once; the loop
        # stops after a header that splits into 10 or fewer fragments.
        header = [''] * 11
        while 'TIME STEP' not in blank and len(header) > 10:
            header = infile.readline().split('K')
            try:
                header.remove('(\n')  # dangling "K(" at the end of the line
            except ValueError:
                pass  # this header has no dangling fragment
            kindx = _column_indices(header[1:])
            for iz in range(nz):
                vals = infile.readline().split()
                for ix, kcol in enumerate(kindx):
                    k[iz, kcol - 1] = float(vals[ix + 2])
            blank = infile.readline()

        # PHOTOCHEMICAL RATE CONSTANTS: also written into k; column 2 gives z.
        _read_until(infile, 'PHOTOCHEMICAL RATE CONSTANTS:')
        infile.readline()  # blank
        z = np.zeros(nz)
        header = infile.readline()
        while len(header.split()) > 1:
            parts = header.split('K')
            # NOTE(review): the last fragment is dropped for short headers or
            # a dangling "K(" -- presumably a wrapped label; confirm against
            # the model's output format.
            if len(parts) < 11 or '(\n' in parts:
                kindx = _column_indices(parts[1:-1])
            else:
                kindx = _column_indices(parts[1:])
            for iz in range(nz):
                vals = infile.readline().split()
                for ix, kcol in enumerate(kindx):
                    k[iz, kcol - 1] = float(vals[ix + 2])
                if kindx:
                    z[iz] = float(vals[1])
            infile.readline()  # blank
            header = infile.readline()

        # REACTION RATES: blocks of columns headed "RATE(i)".
        rates = np.zeros((nz, len(rnames)))
        _read_until(infile, 'REACTION RATES:')
        infile.readline()  # blank
        header = infile.readline()
        while len(header.split()) > 1:
            parts = header.split('RATE')
            if len(parts) < 11:
                kindx = _column_indices(parts[1:-1])
            else:
                kindx = _column_indices(parts[1:])
            for iz in range(nz):
                vals = infile.readline().split()
                for ix, rcol in enumerate(kindx):
                    rates[iz, rcol - 1] = float(vals[ix + 2])
            # Skip the three separator lines that follow each block.
            infile.readline()
            infile.readline()
            infile.readline()
            header = infile.readline()

    return {'rnames': rnames, 'k': k, 'rates': rates, 'z': z}

# Read the user's control file, 'plotks_input.dat'. Each line is one of
#   file <path> <label>    model output to compare (at most five)
#   k <reaction|ALL>       reaction whose rate constants are plotted
#   rate <reaction|ALL>    reaction whose rates are plotted
#   figure_path <prefix>   string prepended to every saved figure name
# Reaction names are given without spaces, matching how filereader stores
# them. Reading stops at the first blank line or at end of file.
files = []
labels = []
ks, rates = [], []
figure_path = ''  # default: save figures in the current working directory
with open('plotks_input.dat', 'r') as infile:
    for line in infile:
        fields = line.split()
        if not fields:
            break
        if fields[0] == 'file':
            files.append(fields[1])
            labels.append(fields[2])
        elif fields[0] == 'k':
            ks.append(fields[1])
        elif fields[0] == 'rate':
            rates.append(fields[1])
        elif fields[0] == 'figure_path':
            figure_path = fields[1]
        else:
            print('INVALID INPUT: col. 1 must specify: file, k, rate or figure_path')
            raise SystemExit(1)

if not files:
    # Every plot below indexes labels[0], so at least one output is required.
    print('INVALID INPUT: at least one "file" line is required')
    raise SystemExit(1)
if len(files) > 5:
    # Only five line styles are defined for distinguishing outputs.
    print('INVALID INPUT: can only compare up to five outputs at a time')
    raise SystemExit(1)

# Parse every output file once, keyed by its user-supplied label.
dat = {label: filereader(fname) for fname, label in zip(files, labels)}

# Plot Ks: one figure per reaction showing its rate constant against z for
# every output file.
lns = ['-', '--', '-.', ':', (0, (1, 11))]  # one line style per file (max 5)
clr = ['red', 'blue', 'green', 'purple', 'orange', 'black', 'cyan', 'brown',
       'magenta', 'pink', 'yellow']
# 'ALL' plots every reaction of the first file; otherwise only the named ones.
k_names = dat[labels[0]]['rnames'] if 'ALL' in ks else ks
for rname in k_names:
    fig, ax = plt.subplots(figsize=(6, 6))
    for ifile, (fname, label) in enumerate(zip(files, labels)):
        try:
            rindx = dat[label]['rnames'].index(rname)
        except ValueError:
            # Skip this file instead of plotting with a stale/undefined index.
            print(rname, 'not found in', fname)
            continue
        # can change 'semilogx' to plot
        ax.semilogx(dat[label]['k'][:, rindx], dat[label]['z'],
                    linestyle=lns[ifile], color=clr[ifile], linewidth=2,
                    label=label)
    if ax.lines:  # avoid an empty-legend warning when nothing was plotted
        ax.legend()
    ax.set_xlabel('reaction k')
    ax.set_title(rname)
    # ax.set_xlim([1e-10, 1e-6])  # uncomment & change values
    # ax.set_ylim([0, 300])       # uncomment & change values
    fig.savefig(figure_path + rname + '_k.png', dpi=200)
    plt.close(fig)  # 'ALL' can open one figure per reaction; free each one

# Plot reaction rates: one figure per reaction showing its rate against z for
# every output file. For an explicit reaction list, a top x-axis also shows
# the ratio of the second file's rate to the first file's.
show_ratio = 'ALL' not in rates
# 'ALL' plots every reaction of the first file; otherwise only the named ones.
rate_names = rates if show_ratio else dat[labels[0]]['rnames']
for rname in rate_names:
    fig, ax = plt.subplots(figsize=(6, 6))
    ax1 = ax.twiny() if show_ratio else None
    # Column of this reaction in each file's 'rates' array, keyed by file
    # position; reactions need not share an index across files.
    rcols = {}
    for ifile, (fname, label) in enumerate(zip(files, labels)):
        try:
            rcols[ifile] = dat[label]['rnames'].index(rname)
        except ValueError:
            print(rname, 'not found in', fname)
            continue
        # can change 'semilogx' to plot
        ax.semilogx(dat[label]['rates'][:, rcols[ifile]], dat[label]['z'],
                    linestyle=lns[ifile], color=clr[ifile], linewidth=2,
                    label=label)
    if show_ratio and 0 in rcols and 1 in rcols:
        ratio = (dat[labels[1]]['rates'][:, rcols[1]]
                 / dat[labels[0]]['rates'][:, rcols[0]])
        ax1.plot(ratio, dat[labels[1]]['z'], linestyle='-', color='k',
                 linewidth=2, label='ratio')
        ax1.set_xlim(0.0, 4.5)
    if ax.lines:  # avoid an empty-legend warning when nothing was plotted
        ax.legend()
    ax.set_xlabel('reaction rate')
    ax.set_ylabel('Altitude')
    ax.set_title(rname)
    # ax.set_xlim([1e-10, 1e-6])  # uncomment & change values
    # ax.set_ylim([0, 300])       # uncomment & change values
    fig.savefig(figure_path + rname + '_rate.png', dpi=200)
    plt.close(fig)  # 'ALL' can open one figure per reaction; free each one
